# -*-Python-*-
# MidiAutoencoder (base class)

import ddsp
import ddsp.training

include 'eval/midi_ae.gin'
include 'models/midiae/mixins/recon_lossgroup.gin'

# Training
# Binds the `num_steps` parameter of the `train` configurable.
# NOTE(review): presumably the total number of training steps to run —
# confirm against the `train` function in ddsp.training.
train.num_steps = 500000


# ==============================================================================
# Model
# ==============================================================================
# The trailing `()` makes this an evaluated reference: gin calls the
# `models.MidiAutoencoder` configurable and binds its result to
# `get_model.model`, rather than passing the configurable itself.
get_model.model = @models.MidiAutoencoder()


# Preprocessor
# The preprocessor is the result of calling the F0LoudnessPreprocessor
# configurable (evaluated reference).
MidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
# NOTE(review): presumably the number of frames per example that the
# preprocessor resamples its features to — confirm against
# F0LoudnessPreprocessor.
F0LoudnessPreprocessor.time_steps = 1000


# Synthcoder
# The synthcoder is the result of calling the DilatedConvDecoder configurable.
# The unscoped `DilatedConvDecoder.*` bindings below set its parameters.
MidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder()
DilatedConvDecoder.ch = 128
# NOTE(review): together with `stacks = 2` below this presumably yields
# 2 x 9 = 18 dilated conv layers in total — confirm against
# DilatedConvDecoder.
DilatedConvDecoder.layers_per_stack = 9
DilatedConvDecoder.norm_type = 'layer'
# NOTE(review): these keys are presumably produced by the
# F0LoudnessPreprocessor configured for this model — confirm.
DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled')
DilatedConvDecoder.stacks = 2
# Both optional inputs are set to None in this base config.
# NOTE(review): presumably this disables conditioning and the precondition
# stack — confirm against DilatedConvDecoder.
DilatedConvDecoder.conditioning_keys = None
DilatedConvDecoder.precondition_stack = None
# (name, size) pairs for the decoder outputs; the sizes sum to
# 1 + 60 + 65 = 126.
DilatedConvDecoder.output_splits = (('amplitudes', 1),
                                    ('harmonic_distribution', 60),
                                    ('magnitudes', 65))


# Stop Gradient
# NOTE(review): judging by the name, this presumably stops gradients from
# flowing back past the point where the MIDI autoencoder part begins —
# confirm the exact placement in MidiAutoencoder.
MidiAutoencoder.sg_before_midiae = True


# MIDI Encoder
# No MIDI encoder is bound in this base config.
# NOTE(review): presumably configs that include this file override this
# binding with a concrete encoder — confirm against the derived configs.
MidiAutoencoder.midi_encoder = None


# MIDI Decoder
# The MIDI decoder is the result of calling the MidiToHarmonicDecoder
# configurable (evaluated reference).
MidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder()
MidiToHarmonicDecoder.f0_residual = True
MidiToHarmonicDecoder.norm = True
# (name, size) pairs for the decoder outputs: an extra 'f0_midi' entry of
# size 1 followed by the same three entries and sizes as the synthcoder's
# output_splits. The sizes sum to 1 + 1 + 60 + 65 = 127.
MidiToHarmonicDecoder.output_splits = (('f0_midi', 1),
                                       ('amplitudes', 1),
                                       ('harmonic_distribution', 60),
                                       ('magnitudes', 65))
# The network is a DilatedConvStack constructed inside the gin scope `dec`.
# The `dec/`-prefixed bindings below apply only to DilatedConvStack instances
# created within that scope, leaving unscoped uses of the class unaffected.
MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack()
dec/DilatedConvStack.ch = 128
# NOTE(review): with `stacks = 4` this presumably yields 4 x 5 = 20 dilated
# conv layers in total — confirm against DilatedConvStack.
dec/DilatedConvStack.layers_per_stack = 5
dec/DilatedConvStack.stacks = 4
dec/DilatedConvStack.norm_type = 'layer'
dec/DilatedConvStack.conditional = False


# ==============================================================================
# Losses
# ==============================================================================
# The reconstruction losses are a LossGroup constructed inside the gin scope
# `recon_lossgroup`, so only `recon_lossgroup/`-scoped (and unscoped)
# LossGroup bindings apply to it.
# NOTE(review): those scoped bindings are presumably supplied by the included
# 'models/midiae/mixins/recon_lossgroup.gin' — confirm.
MidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup()


